import torch
from torch import Tensor

from src.envs.base_environment import ContinuousEnvironment  



class HypergridEnvironment(ContinuousEnvironment):
    """
    ### Description

    The hypergrid environment is an N-dimensional continuous environment (N = `grid_dimension`).
    A state row holds the N coordinates (x1, x2, ..., xN) followed by a step counter in its last column.
    The action is an N-dimensional vector representing the change in the (x1, x2, ..., xN) coordinates.
    The goal is to sample a reward distribution with high reward peaks at the 2**N corners
    (+/-edge_size, ..., +/-edge_size) of a hypercube centred at the origin.

    ### Policy Parameterisation

    The policy is parameterised as a mixture model with `mixture_dim` components.
    The mixture is an N-variate mixture of Gaussians with diagonal covariance.

    ### Arguments

    All of the following are read from `config["env"]`; the torch device is read from `config["device"]`.

    - `max_policy_std`: Maximum sigma parameter for the Gaussian distribution.
    - `min_policy_std`: Minimum sigma parameter for the Gaussian distribution.
    - `grid_dimension`: Number of coordinates N of the state space.
    - `edge_size`: Absolute value of every coordinate of the reward peaks; also bounds the policy means.
    - `mixture_dim`: Number of components in the Gaussian mixture model in the parameterisation of the policy.
    - `num_grid_points`: Number of grid points in each dimension of the state space
      (not read by this class; presumably consumed elsewhere - TODO confirm).
    """

    def __init__(
            self,
            config):
        """Builds the corner-peaked reward mixture and forwards the environment geometry to the base class."""
        self._init_required_params(config)
        device = config["device"]
        dim = self.grid_dimension
        num_corners = 2 ** dim

        # The state-space bounds are fixed at +/-15 in every coordinate.
        lower_bound = torch.tensor([-15] * dim, device=device)
        upper_bound = torch.tensor([15] * dim, device=device)

        # Corner `c` of the hypercube is given by the binary expansion of `c` (most significant bit
        # first): a 0 bit maps to -edge_size and a 1 bit to +edge_size along that axis. The sign table
        # is built in Python and turned into a tensor once, instead of writing element by element.
        corner_signs = [
            [1.0 if (corner >> (dim - 1 - axis)) & 1 else -1.0 for axis in range(dim)]
            for corner in range(num_corners)
        ]
        means = torch.tensor(corner_signs, device=device) * self.edge_size

        # One unit-covariance Gaussian per corner; `log_reward` combines them.
        identity = torch.eye(dim, device=device)
        self.mixture = [
            torch.distributions.MultivariateNormal(mean, identity) for mean in means
        ]
        super().__init__(config,
                    dim = dim,
                    feature_dim = dim,
                    angle_dim = [False] * dim,
                    action_dim = dim,
                    lower_bound = lower_bound,
                    upper_bound = upper_bound,
                    mixture_dim = config["env"]["mixture_dim"],
                    output_dim = (2 * dim + 1) * config["env"]["mixture_dim"])

    def _init_required_params(self, config):
        """Checks that every `config["env"]` entry this class reads is present and caches the scalar ones on `self`."""
        required_params = ["max_policy_std", "min_policy_std", "grid_dimension", "edge_size", "mixture_dim"]
        missing = [param for param in required_params if param not in config["env"]]
        assert not missing, f"Missing required parameters: {missing}"
        self.max_policy_std = config["env"]["max_policy_std"]
        self.min_policy_std = config["env"]["min_policy_std"]
        self.grid_dimension = config["env"]["grid_dimension"]
        self.edge_size = config["env"]["edge_size"]

    def log_reward(self, x: Tensor) -> Tensor:
        """Returns the log of the summed (unnormalised) densities of all corner Gaussians evaluated at x."""
        per_corner = torch.stack([component.log_prob(x) for component in self.mixture], 0)
        return torch.logsumexp(per_corner, 0)

    def step(self, x: Tensor, action: Tensor) -> Tensor:
        """Takes a step in the environment given an action. x is the current state and action is the action to take. Returns the new state."""
        new_x = torch.zeros_like(x)
        # update the N coordinates
        new_x[:, :-1] = x[:, :-1] + action
        # increment the step counter (last column, indexed the same way as in `backward_step`)
        new_x[:, -1] = x[:, -1] + 1

        return new_x

    def backward_step(self, x: Tensor, action: Tensor) -> Tensor:
        """Takes a backward step in the environment. x is the current state and action is the action that had been taken to reach x. Returns the previous state."""
        new_x = torch.zeros_like(x)
        # undo the update of the N coordinates
        new_x[:, :-1] = x[:, :-1] - action
        # decrement the step counter
        new_x[:, -1] = x[:, -1] - 1

        return new_x

    def compute_initial_action(self, first_state):
        """Returns the offset of first_state from `self.init_value` (`init_value` is not set in this class; presumably provided by the base class - TODO confirm)."""
        return (first_state - self.init_value)

    def _init_policy_dist(self, param_dict):
        """Initialises a mixture of multivariate Gaussians with diagonal covariance. Used for policy parameterisation."""
        mus, sigmas, weights = param_dict["mus"], param_dict["sigmas"], param_dict["weights"]
        # reshape to batch_size x mixture_dim x grid_dimension
        mus = mus.view(-1, self.mixture_dim, self.grid_dimension)
        sigmas = sigmas.view(-1, self.mixture_dim, self.grid_dimension)
        # NOTE(review): `sigmas` is placed directly on the diagonal of the covariance matrix, i.e. it
        # acts as a variance, although the config names the bounds `*_std` - confirm this is intended.
        covs = torch.diag_embed(sigmas)

        # Define the mixture components
        mix = torch.distributions.Categorical(weights)
        components = torch.distributions.MultivariateNormal(mus, covariance_matrix=covs)

        # Combine into a MixtureSameFamily distribution
        return torch.distributions.MixtureSameFamily(mix, components)

    def postprocess_params(self, params):
        """Postprocesses the parameters of the policy distribution to ensure they are within the correct range(s).

        The columns of `params` are read as [mus | sigmas | weights], where the first two groups have
        `grid_dimension * mixture_dim` columns each and the weights have `mixture_dim` columns.
        Returns a dict with the keys "mus", "sigmas" and "weights".
        """
        num_loc = self.grid_dimension * self.mixture_dim
        mus_params = params[:, :num_loc]
        sigmas_params = params[:, num_loc: 2 * num_loc]
        # There is one weight per mixture component, so this group is only `mixture_dim` columns wide.
        weight_params = params[:, 2 * num_loc: 2 * num_loc + self.mixture_dim]

        # Restrict every mean to the range (-edge_size, edge_size)
        mus = torch.sigmoid(mus_params) * (2 * self.edge_size) - self.edge_size
        # Restrict every sigma to the range (min_policy_std, max_policy_std)
        sigmas = torch.sigmoid(sigmas_params) * (self.max_policy_std - self.min_policy_std) + self.min_policy_std

        # Normalise the mixture weights so that they sum to one for each row
        weights = torch.softmax(weight_params, dim=1)
        param_dict = {"mus": mus, "sigmas": sigmas, "weights": weights}

        return param_dict

    def add_noise(self, param_dict: dict, off_policy_noise: float):
        """Adds noise to the policy parameters for noisy exploration.

        `off_policy_noise` is added to `param_dict["sigmas"]` with `+=`, so the passed dict is
        updated rather than copied; the same dict is returned.
        """
        param_dict["sigmas"] += off_policy_noise

        return param_dict
    